import ipaddress
import logging
import pickle
import socket
import threading
import time
from typing import Tuple, Sequence, Any

import rlp
from eth_keys import datatypes
from eth_keys.datatypes import PublicKey
from eth_typing import Hash32
from eth_utils import encode_hex

from protocol_utils import EXPIRATION, pack_v4, PROTOCOL_VERSION, enc_port, CMD_PING, MAC_SIZE, formatted_date, \
    int_to_big_endian, KADEMLIA_PUBLIC_KEY_SIZE, CMD_PONG, CMD_FIND_NODE, DISCOVERY_DATAGRAM_BUFFER_SIZE, unpack_v4, \
    CMD_NEIGHBOURS, big_endian_to_int, remote_to_str, extract_nodes_from_payload
from .config import FROM_UDP_PORT, FROM_TCP_PORT, COMPLETED_UP_POOL, ASKED_POOL, PING_WAITING_POOL, \
    NEIGHBORS_WAITING_POOL, ADDED_POOL, PING_PENDING_POOL, NEIGHBORS_PENDING_POOL, UP_POOL
from .data_management import remote_to_key
from .data_management.redis import redis_init


class Connection(object):
    def __init__(self, from_addr=("127.0.0.1", FROM_UDP_PORT, FROM_TCP_PORT),
                 **conf):
        """Set up the connection state; no socket is opened until open() is called.

        :param from_addr: ``(ip, udp_port, tcp_port)`` advertised as our own endpoint
            in PING packets; ``udp_port`` is also the port open() binds to.
        :param conf: optional settings:
            ``logger`` (defaults to the ``logging`` module),
            ``expiration`` (defaults to ``EXPIRATION``; used as the socket timeout
            and as the validity window of outgoing messages),
            ``private_key`` (handed to ``pack_v4`` when sending; defaults to ``None``),
            ``redis_db`` (defaults to the result of ``redis_init()``).
        """
        self.from_addr: Tuple[str, int, int] = from_addr
        self.logger = conf.get('logger', logging)
        self.expiration: int = conf.get('expiration', EXPIRATION)
        self.private_key: datatypes.PrivateKey = conf.get('private_key', None)
        # None until open() succeeds, and again after close().
        self.socket: "socket.socket | None" = None
        self.serve_thread = None
        # Only call redis_init() when the caller did not supply a client:
        # ``conf.get('redis_db', redis_init())`` would evaluate redis_init()
        # on every construction, even when its result is thrown away.
        self.redis_db = conf['redis_db'] if 'redis_db' in conf else redis_init()
        self.collect_neighbors = True

    def open(self):
        """Bind the UDP socket and start the receive loop in a background thread.

        The socket is bound on all interfaces to ``self.from_addr[1]`` and uses
        ``self.expiration`` as its timeout. It is only published on
        ``self.socket`` once binding succeeded.

        :raises OSError: if the socket cannot be created or bound; in that case the
            half-initialised socket is closed and ``self.socket`` is left untouched.
        """
        sock = socket.socket(socket.AF_INET,
                             socket.SOCK_DGRAM)
        try:
            sock.settimeout(self.expiration)
            sock.bind(('', self.from_addr[1]))
        except Exception:
            # Do not leak the descriptor (or leave an unbound socket on self)
            # when the port is unavailable.
            sock.close()
            raise
        self.socket = sock
        self.logger.info("Socket opened and binded on port %d", self.from_addr[1])
        # The thread is not marked as daemon: it ends when serve_forever() returns,
        # i.e. after close() or after a receive timeout.
        self.serve_thread = threading.Thread(target=self.serve_forever)
        self.serve_thread.start()

    def close(self):
        """Close the socket if one is open; calling it again is a no-op.

        ``self.socket`` is reset to ``None`` *before* the socket is torn down so
        that serve_forever() and send() can tell an intentional close from a
        genuine error.
        """
        sock = self.socket
        if not sock:
            return
        self.socket = None
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            # shutdown() may be rejected (e.g. on a datagram socket that was never
            # connected); the socket is still closed below.
            pass
        finally:
            sock.close()
        # Logged after close() so the message is emitted even when shutdown() failed.
        self.logger.info("Socket closed")

    def send(self, remote: Tuple[str, int, int], cmd_id: int, payload: Sequence[Any]):
        """Pack ``payload`` as a ``cmd_id`` message and send it to ``remote`` over UDP.

        Only ``remote[0]`` (IP) and ``remote[1]`` (UDP port) are used as destination.
        Sending is best effort: any exception raised while sending is logged (unless
        the socket has already been closed) and the message is lost.

        :return: the packed message, whether or not it was actually sent.
        """
        datagram = pack_v4(cmd_id, payload, self.private_key)
        try:
            self.socket.sendto(datagram, (remote[0], remote[1]))
        except Exception as err:
            # After close() the socket is None; failures at that point stay silent.
            if not self.socket:
                return datagram
            self.logger.error("Unexpected error when sending msg to %s:", remote)
            self.logger.error(err)
        return datagram

    def _get_msg_expiration(self):
        """Return ``now + self.expiration`` (whole seconds) serialized with rlp's big_endian_int."""
        deadline = int(time.time() + self.expiration)
        return rlp.sedes.big_endian_int.serialize(deadline)

    def _is_msg_expired(self, rlp_expiration: bytes) -> bool:
        """Tell whether a received expiration timestamp lies in the past.

        :param rlp_expiration: expiration as serialized by rlp's big_endian_int sedes.
        :return: ``True`` (after logging a warning) if the current time is past it.
        """
        deadline = rlp.sedes.big_endian_int.deserialize(rlp_expiration)
        expired = time.time() > deadline
        if expired:
            self.logger.warning('Received message already expired')
        return expired

    def _find_node(self, remote: Tuple[str, int, int], target_key: bytes) -> None:
        """Send FIND_NODE to ``remote`` unless it is already in COMPLETED_UP_POOL.

        Nothing is sent either when ``self.collect_neighbors`` is false. When a
        request is sent, the remote is first recorded in ASKED_POOL.
        """
        remote_key = remote_to_key(remote)
        already_completed = self.redis_db.sismember(COMPLETED_UP_POOL, remote_key)
        if already_completed or not self.collect_neighbors:
            return
        # No NEIGHBOURS answer has been recorded for this node yet.
        self.redis_db.sadd(ASKED_POOL, remote_key)
        self.send_find_node_v4(remote, target_key)

    def ping(self, remote: Tuple[str, int, int]):
        """Send a PING to ``remote`` and record it in Redis.

        The stored node record gets its ``'Ping'`` counter incremented and a
        ``'Ping <n>'`` timestamp; the hex token is mapped to the node key so the
        matching PONG can be resolved, and the node joins PING_WAITING_POOL.

        :return: the first ``MAC_SIZE`` bytes of the sent message, used as the token
            that identifies the PONG.
        """
        def _endpoint(addr):
            # [packed ip, udp port, tcp port] as carried in the PING payload.
            return [ipaddress.ip_address(addr[0]).packed, enc_port(addr[1]), enc_port(addr[2])]

        version = rlp.sedes.big_endian_int.serialize(PROTOCOL_VERSION)
        expiration = self._get_msg_expiration()
        payload = (version, _endpoint(self.from_addr), _endpoint(remote), expiration)
        packet = self.send(remote, CMD_PING, payload)
        token = Hash32(packet[:MAC_SIZE])

        remote_key = remote_to_key(remote)
        record = pickle.loads(self.redis_db.get(remote_key))
        record['Ping'] += 1
        record['Ping ' + str(record['Ping'])] = formatted_date()
        self.redis_db.set(remote_key, pickle.dumps(record))

        token_hex = encode_hex(token)
        self.redis_db.set(pickle.dumps(token_hex), remote_key)
        self.redis_db.sadd(PING_WAITING_POOL, remote_key)
        self.logger.debug('>>> ping (v4) %s (token == %s)', remote, token_hex)

        return token

    def find_nodes_fix(self, remote: Tuple[str, int, int]):
        """Send FIND_NODE to ``remote`` using the remote's own stored node ID as target.

        The stored record gets its ``'Ping_ng'`` counter incremented and a
        ``'Ping_ng <n>'`` timestamp, and the node joins NEIGHBORS_WAITING_POOL
        before the request is sent. If the record has no ``'Node ID'`` it is only
        printed and nothing is sent.
        """
        remote_key = remote_to_key(remote)
        record = pickle.loads(self.redis_db.get(remote_key))

        if "Node ID" not in record:
            print(f"\n{record}\n")
            return

        node_id_hex = record["Node ID"]
        node_id_int = int(node_id_hex, base=16)
        self.logger.debug('>>> find_node_fix nodeid : %s -- %s', node_id_int, node_id_hex)
        # Left-pad with zero bytes to the full public-key length expected by FIND_NODE.
        key_len = KADEMLIA_PUBLIC_KEY_SIZE // 8
        target = int_to_big_endian(node_id_int).rjust(key_len, b"\x00")

        record['Ping_ng'] += 1
        record['Ping_ng ' + str(record['Ping_ng'])] = formatted_date()
        self.redis_db.set(remote_key, pickle.dumps(record))
        self.redis_db.sadd(NEIGHBORS_WAITING_POOL, remote_key)
        self.send_find_node_v4(remote, target)

    def send_pong_v4(self, remote: Tuple[str, int, int], token: Hash32) -> None:
        """Answer a PING: send a PONG carrying the remote's endpoint, ``token`` and an expiration."""
        expiration = self._get_msg_expiration()
        self.logger.debug('>>> pong %s', remote)
        ip, udp_port, tcp_port = remote[0], remote[1], remote[2]
        recipient = [ipaddress.ip_address(ip).packed, enc_port(udp_port), enc_port(tcp_port)]
        self.send(remote, CMD_PONG, (recipient, token, expiration))

    def send_find_node_v4(self, remote: Tuple[str, int, int], target_key: bytes) -> None:
        """Send a FIND_NODE request for ``target_key`` to ``remote``.

        :raises ValueError: if ``target_key`` is not ``KADEMLIA_PUBLIC_KEY_SIZE // 8`` bytes long.
        """
        required_len = KADEMLIA_PUBLIC_KEY_SIZE // 8
        if len(target_key) != required_len:
            raise ValueError(f"Invalid FIND_NODE target ({target_key!r}). Length is not 64")
        payload = (target_key, self._get_msg_expiration())
        self.logger.debug('>>> find_node %s to %s', target_key, remote)
        self.send(remote, CMD_FIND_NODE, payload)

    def serve_forever(self) -> None:
        """Receive datagrams and dispatch them until the socket is closed or idle.

        Datagrams are handled inline, one at a time, on the calling thread. The
        loop ends when close() resets ``self.socket`` or when no datagram arrives
        within the socket timeout, in which case the socket is closed here.
        Any other error is logged and the loop carries on.
        """
        self.logger.info("Serving...")
        while True:
            # Snapshot the socket: close() may reset self.socket to None from
            # another thread between the check and the recvfrom() call.
            sock = self.socket
            if not sock:
                return
            try:
                datagram, (ip_address, port) = sock.recvfrom(DISCOVERY_DATAGRAM_BUFFER_SIZE)
                # Only the UDP source port is known, so it also stands in for the TCP port.
                self.consume_datagram((ip_address, port, port), datagram)
            except socket.timeout:
                # socket.timeout is an alias of TimeoutError on Python >= 3.10 and a
                # separate OSError subclass before that, so this catches both.
                self.close()
                return
            except ConnectionResetError as e:
                self.logger.warning(e)
            except Exception as e:
                # Errors caused by an intentional close() are not worth reporting.
                if self.socket:
                    self.logger.error(e)

    def consume_datagram(self, address: Tuple[str, int, int], datagram):
        """Log the reception of ``datagram`` from ``address`` and hand it to handle_msg()."""
        # The default logger is the stdlib ``logging`` module, which has no
        # ``trace`` method; fall back to ``debug`` instead of raising AttributeError.
        trace = getattr(self.logger, 'trace', None) or self.logger.debug
        trace("Received datagram from %s", address)
        self.handle_msg(address, datagram)

    def handle_msg(self, address: Tuple[str, int, int], message: bytes) -> None:
        """Unpack a received message, update the sender's record and dispatch on command id.

        If a record is stored for the sender, its previous ``'Node ID'`` (when set)
        is added to ``'Seen node IDs'`` and ``'Node ID'`` is replaced by the public
        key the message was signed with. FIND_NODE requests are ignored.

        :raises ValueError: if the command id is none of PING, PONG, FIND_NODE, NEIGHBOURS.
        """
        try:
            remote_pubkey, cmd_id, payload, message_hash = unpack_v4(message)
        except SyntaxError as e:
            self.logger.error('Error unpacking message (%s) from %s: %s', message, address, e)
            return

        remote_key = remote_to_key(address)
        # A single GET replaces EXISTS + GET: one round trip, and no gap in which
        # the key could vanish between the check and the read.
        raw_node = self.redis_db.get(remote_key)
        if raw_node is not None:
            # NOTE(review): records are unpickled straight from Redis; anyone able to
            # write to that Redis instance can therefore execute code in this process.
            node = pickle.loads(raw_node)
            if node.get('Node ID'):
                node['Seen node IDs'].add(node['Node ID'])
            node['Node ID'] = remote_pubkey.to_hex()
            self.redis_db.set(remote_key, pickle.dumps(node))

        # The default logger is the stdlib ``logging`` module, which has no
        # ``trace`` method; fall back to ``debug`` instead of raising AttributeError.
        trace = getattr(self.logger, 'trace', None) or self.logger.debug
        trace("Received cmd %s from %s with payload: %s", cmd_id, address, payload)
        if cmd_id == CMD_PING:
            return self.recv_ping_v4(address, remote_pubkey, payload, message_hash)
        elif cmd_id == CMD_PONG:
            return self.recv_pong_v4(address, remote_pubkey, payload, message_hash)
        elif cmd_id == CMD_FIND_NODE:
            # Requests from other nodes are deliberately left unanswered.
            return
        elif cmd_id == CMD_NEIGHBOURS:
            return self.recv_neighbours_v4(address, payload, message_hash)
        else:
            raise ValueError(f"Unknown command id: {cmd_id}")

    def recv_ping_v4(
            self, remote: Tuple[str, int, int], remote_pubkey: PublicKey, payload: Sequence[Any],
            message_hash: Hash32) -> None:
        """Process a received PING packet.

        A sender that is not yet in ADDED_POOL is registered: a fresh record is
        stored for it and it joins ADDED_POOL and PING_PENDING_POOL; no PONG is
        sent in that case. A sender that is already known gets a PONG echoing
        ``message_hash`` and, unless it is already in ASKED_POOL, is added to
        ASKED_POOL and NEIGHBORS_PENDING_POOL.

        NOTE(review): the expiration is only checked for its logging side effect;
        the result of ``_is_msg_expired`` is not acted upon, so expired PINGs are
        still processed. Confirm this is intended.
        """
        # Expected payload: [version, from, to, expiration], optionally followed by
        # the sender's ENR sequence number.
        if len(payload) < 4:
            self.logger.warning('Ignoring PING msg with invalid payload: %s', payload)
            return
        expiration = payload[3]
        enr_seq = big_endian_to_int(payload[4]) if len(payload) > 4 else None
        self.logger.debug('<<< ping(v4) from %s, enr_seq=%s', remote, enr_seq)
        self._is_msg_expired(expiration)

        remote_key = remote_to_key(remote)
        if self.redis_db.sismember(ADDED_POOL, remote_key):
            self.send_pong_v4(remote, message_hash)
            if not self.redis_db.sismember(ASKED_POOL, remote_key):
                self.redis_db.sadd(ASKED_POOL, remote_key)
                self.redis_db.sadd(NEIGHBORS_PENDING_POOL, remote_key)
            return

        # The node pinged us before we discovered it through anyone else.
        self.redis_db.sadd(ADDED_POOL, remote_key)
        record = {
            'IP address': remote[0],
            'UDP port': remote[1],
            'TCP port': remote[2],
            'Pending': formatted_date(),
            'Ping': 0,
            'Ping_ng': 0,
            'Parents': {remote_to_str(remote)},
            'Node ID': remote_pubkey.to_hex(),
            'Seen node IDs': {remote_pubkey.to_hex()},
        }
        self.redis_db.set(remote_key, pickle.dumps(record))
        self.redis_db.sadd(PING_PENDING_POOL, remote_key)

    def recv_pong_v4(self, remote: Tuple[str, int, int], remote_pubkey: PublicKey, payload: Sequence[Any],
                     _: Hash32) -> None:
        """Process a received PONG packet.

        The token is resolved to the key of the node that was pinged (the mapping
        is written by ping()). That node is moved from PING_WAITING_POOL to
        UP_POOL, its record gets an ``'Up'`` timestamp and, unless it is already
        in ASKED_POOL, it is queued in ASKED_POOL and NEIGHBORS_PENDING_POOL.

        A PONG whose token does not match any ping we sent, or whose node record
        is missing, is logged and ignored.

        NOTE(review): the result of ``_is_msg_expired`` is not acted upon, so an
        expired PONG is still processed. Confirm this is intended.
        """
        # Expected payload: [to, token, expiration], optionally followed by an ENR
        # sequence number, which is not used here.
        if len(payload) < 3:
            self.logger.warning('Ignoring PONG msg with invalid payload: %s', payload)
            return
        token, expiration = payload[1], payload[2]
        self._is_msg_expired(expiration)

        token_hex = encode_hex(token)
        remote_key = self.redis_db.get(pickle.dumps(token_hex))
        if remote_key is None:
            # Unsolicited or forged PONG: passing None on to the pool operations
            # would fail inside the Redis client.
            self.logger.warning('Ignoring PONG from %s with unknown token %s', remote, token_hex)
            return
        raw_node = self.redis_db.get(remote_key)
        if raw_node is None:
            self.logger.warning('Ignoring PONG from %s: no stored record for token %s', remote, token_hex)
            return

        self.redis_db.srem(PING_WAITING_POOL, remote_key)
        self.redis_db.sadd(UP_POOL, remote_key)
        # NOTE(review): records are unpickled straight from Redis; anyone able to
        # write to that Redis instance can therefore execute code in this process.
        node = pickle.loads(raw_node)
        node['Up'] = formatted_date()
        self.redis_db.set(remote_key, pickle.dumps(node))
        self.logger.debug('<<< pong (v4) from %s (token == %s)', remote, token_hex)

        # Queue the node for a FIND_NODE unless that was already done when it pinged us.
        if not self.redis_db.sismember(ASKED_POOL, remote_key):
            self.redis_db.sadd(ASKED_POOL, remote_key)
            self.redis_db.sadd(NEIGHBORS_PENDING_POOL, remote_key)

    def recv_neighbours_v4(self, remote: Tuple[str, int, int], payload: Sequence[Any], _: Hash32) -> None:
        """Process a received NEIGHBOURS packet.

        The sender is moved from NEIGHBORS_WAITING_POOL to COMPLETED_UP_POOL and
        its record gets an ``'Up'`` timestamp. Every advertised neighbour that is
        not yet in ADDED_POOL is registered with a fresh record and queued in
        PING_PENDING_POOL; for an already known neighbour the sender is added to
        its ``'Parents'`` (when that set is non-empty) and the advertised node ID
        to its ``'Seen node IDs'``.

        Packets from a sender without a stored record are logged and ignored
        before any pool is modified.

        NOTE(review): the result of ``_is_msg_expired`` is not acted upon, so an
        expired packet is still processed. Confirm this is intended.
        """
        # Expected payload: [nodes, expiration]
        if len(payload) < 2:
            self.logger.warning('Ignoring NEIGHBOURS msg with invalid payload: %s', payload)
            return
        nodes, expiration = payload[:2]
        self._is_msg_expired(expiration)
        try:
            neighbours = extract_nodes_from_payload(remote, nodes, self.logger)
        except ValueError:
            self.logger.warning("Malformed NEIGHBOURS packet from %s: %s", remote, nodes)
            return

        remote_key = remote_to_key(remote)
        raw_sender = self.redis_db.get(remote_key)
        if raw_sender is None:
            # Without this check the sender would be put into COMPLETED_UP_POOL and
            # the handler would then fail while unpickling a missing record.
            self.logger.warning('Ignoring NEIGHBOURS from unknown node %s', remote)
            return

        self.redis_db.srem(NEIGHBORS_WAITING_POOL, remote_key)
        self.redis_db.sadd(COMPLETED_UP_POOL, remote_key)

        # NOTE(review): records are unpickled straight from Redis; anyone able to
        # write to that Redis instance can therefore execute code in this process.
        sender = pickle.loads(raw_sender)
        sender['Up'] = formatted_date()
        self.redis_db.set(remote_key, pickle.dumps(sender))

        parent = remote_to_str(remote)
        for neighbour in neighbours:
            # Each entry is indexed as (endpoint tuple, public key).
            endpoint, pubkey = neighbour[0], neighbour[1]
            node_id_hex = pubkey.to_hex()
            neighbour_key: bytes = remote_to_key(endpoint)
            if not self.redis_db.sismember(ADDED_POOL, neighbour_key):
                self.redis_db.sadd(ADDED_POOL, neighbour_key)
                self.redis_db.set(neighbour_key, pickle.dumps({'IP address': endpoint[0],
                                                               'UDP port': endpoint[1],
                                                               'TCP port': endpoint[2],
                                                               'Pending': formatted_date(),
                                                               'Ping': 0,
                                                               'Ping_ng': 0,
                                                               'Parents': {parent},
                                                               'Node ID': node_id_hex,
                                                               'Seen node IDs': {node_id_hex}}))
                self.redis_db.sadd(PING_PENDING_POOL, neighbour_key)
                continue

            raw_known = self.redis_db.get(neighbour_key)
            if raw_known is None:
                # Pool membership without a record: skip this entry rather than
                # abort the processing of the remaining neighbours.
                self.logger.warning('No stored record for known neighbour %s', endpoint)
                continue
            known = pickle.loads(raw_known)
            if known.get('Parents'):
                known['Parents'].add(parent)
            known['Seen node IDs'].add(node_id_hex)
            self.redis_db.set(neighbour_key, pickle.dumps(known))
        self.logger.debug('<<< %s neighbours from %s: %s', len(neighbours), remote, neighbours)
